# This script performs feature selection using RFE to identify the CpGs which should be included in the newborn and childhood methylation risk scores (MRSs)
# CpG selection is detailed in: Genomic_risk_score_features.xlsx, sheet: CpGs included in newborn MRS / CpGs included in childhood MRS
# Python version 3.6.8 is used

# Feature Extraction with RFE
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import imblearn
from imblearn.ensemble import BalancedRandomForestClassifier
# Select the non-interactive Agg backend so figures can be written to file
# without a display. `matplotlib` itself must be imported for `matplotlib.use`
# to be callable (the import was previously commented out -> NameError).
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.feature_selection import RFECV

###################
### NEWBORN MRS ###
###################
# Import the IOWBC methylation dataset for the newborn MRS candidate CpGs - based on Reese et al's EWAS meta-analysis, 6 candidate CpGs were considered for inclusion in the newborn MRS. Detailed in: Genomic_risk_score_features.xlsx, sheet: CpGs included in newborn MRS
newborn = pd.read_csv("/scratch/dk2e18/IoW_Methylation_Data/Final_dataset/Beta_QN_autosome_combat_newborn_EWAS_6_CpGs_EPIC_862GU_asthma_747ID.csv", index_col=False)
# Drop the unnamed leading column (presumably the row index written when the CSV was exported - confirm)
del newborn['Unnamed: 0']

# Remove those with NA to identify individuals with complete data - there should be no NAs
complete_newborn = newborn.dropna()

# Separate features and outcome for feature selection:
# the first remaining column is skipped, the last column is the outcome,
# and every column in between is treated as a CpG feature.
X = complete_newborn.iloc[:, 1:-1]
Y = complete_newborn.iloc[:, -1]

# Standardise the CpGs.
# Column labels are taken from the data itself rather than a hard-coded tuple,
# so they cannot drift out of step with the column order of the CSV.
# Expected CpGs: cg13289553, ch.6.1218502R, cg13427149, cg17333211, cg02331902, cg07156990
scaler = StandardScaler()
SX = pd.DataFrame(scaler.fit_transform(X), columns=X.columns)

### Plot a correlation matrix to identify highly correlated CpGs ###
corr = SX.corr(method='spearman')

# Generate a mask for the upper triangle (diagonal included) so each CpG pair is drawn once.
# The builtin `bool` is used because the `np.bool` alias no longer exists in current NumPy.
mask = np.triu(np.ones_like(corr, dtype=bool))
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(15, 15))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=1.0, vmin=-1.0, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5}, ax=ax)
# Title the heatmap axes directly: previously the title was set on a separate,
# empty figure created by plt.figure(), so it never appeared in the saved PDF.
ax.set_title("Heat map of correlation between newborn MRS CpGs")
f.savefig('Spearman correlation of newborn MRS CpGs heatmap.pdf')
# Release the figure once it has been written to disk
plt.close(f)


### Perform RFE - with the balanced random forest algorithm ###
# Define parameters for RFE - used default settings
best_param1 = {'bootstrap': True, 'criterion': 'gini', 'max_depth': None, 'max_features': 'sqrt', 'min_samples_split': 2, 'n_estimators': 100}

# Define RFE model.
# Every entry of best_param1 is passed explicitly (previously 'bootstrap' and
# 'criterion' were listed but never forwarded), so the forest configuration
# does not depend on the library's defaults.
bclf = BalancedRandomForestClassifier(random_state=123, **best_param1)

# 5-fold stratified cross-validation without shuffling.
# `random_state` is intentionally not set: it has no effect when shuffle=False,
# and newer scikit-learn releases raise a ValueError for that combination.
# The resulting folds are the same deterministic, unshuffled folds as before.
rfecv = RFECV(bclf, step=1, cv=StratifiedKFold(n_splits=5), scoring='balanced_accuracy')

# apply RFE to data
fit = rfecv.fit(SX, Y)

# Names of the CpGs retained by RFE (ranking == 1). Derived from the fitted
# selector so that no feature count is hard-coded, and stored under a name
# that does not shadow the builtin `list`.
selected_features = [cpg for cpg, rank in zip(SX.columns, rfecv.ranking_) if rank == 1]

# Mean cross-validated score for each number of features evaluated.
# `grid_scores_` no longer exists in newer scikit-learn releases, where
# `cv_results_["mean_test_score"]` holds the same per-step means.
if hasattr(rfecv, "cv_results_"):
    cv_scores = rfecv.cv_results_["mean_test_score"]
else:
    cv_scores = rfecv.grid_scores_

print("Optimal number of features : %d" % rfecv.n_features_)
# 6 features

# RFECV is run with step=1, so entry k-1 holds the score obtained with k features
print("Accuracy: \n", cv_scores[rfecv.n_features_ - 1])

print("Feature Selected: \n", selected_features)

# Plot number of features against cross-validation scores
plt.figure()
plt.xlabel("Number of features selected")
plt.ylabel("Cross-validation balanced accuracy score")
plt.plot(range(1, len(cv_scores) + 1), cv_scores)
plt.savefig('Feature_selection_balancedRFE_newborn_mrs.pdf')


#####################
### CHILDHOOD MRS ###
#####################
# Load the IOWBC methylation data for the childhood MRS candidate CpGs. Based on Reese et al's EWAS meta-analysis, 157 candidate CpGs were considered for inclusion in the childhood MRS. Detailed in: Genomic_risk_score_features.xlsx, sheet: CpGs included in childhood MRS
childhood = pd.read_csv(
    "/scratch/dk2e18/IoW_Methylation_Data/Final_dataset/Beta_QN_autosome_combat_childhood_EWAS_157_CpGs_EPIC_862GU_asthma_747ID.csv",
    index_col=False,
)
# Discard the unnamed leading column that comes in with the CSV
childhood.drop(columns=['Unnamed: 0'], inplace=True)

# Keep only individuals with complete data - no NAs are expected
complete_childhood = childhood.dropna()

# Split into features and outcome: the first column is skipped, the last
# column is the outcome and everything in between is a feature
X = complete_childhood.iloc[:, 1:-1]
Y = complete_childhood.iloc[:, -1]
columns = X.columns

# Standardise the CpGs, keeping the original column labels
scaler = StandardScaler()
scaled_values = scaler.fit_transform(X)
SX = pd.DataFrame(scaled_values, columns=columns)


### Plot a correlation matrix to identify highly correlated CpGs ###
corr = SX.corr(method='spearman')

# Generate a mask for the upper triangle (diagonal included) so each CpG pair is drawn once.
# The builtin `bool` is used because the `np.bool` alias no longer exists in current NumPy.
mask = np.triu(np.ones_like(corr, dtype=bool))
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(15, 15))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=1.0, vmin=-1.0, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5}, ax=ax)
# Title the heatmap axes directly: previously the title was set on a separate,
# empty figure created by plt.figure(), so it never appeared in the saved PDF.
ax.set_title("Heat map of correlation between childhood MRS CpGs")
f.savefig('Spearman correlation of childhood MRS CpGs heatmap.pdf')
# Release the figure once it has been written to disk
plt.close(f)

# Remove CpGs that are highly correlated - in our case, we removed data for cg08640475.
# This step was previously described but not carried out in the script.
# errors='ignore' keeps it a no-op when the CpG is already absent from the input file.
# NOTE(review): confirm that cg08640475 is the only CpG that should be excluded here.
SX = SX.drop(columns=['cg08640475'], errors='ignore')

### Perform RFE - with the balanced random forest algorithm ###
# Define parameters for RFE - used default settings
best_param1 = {'bootstrap': True, 'criterion': 'gini', 'max_depth': None, 'max_features': 'sqrt', 'min_samples_split': 2, 'n_estimators': 100}

# Define RFE model.
# Every entry of best_param1 is passed explicitly (previously 'bootstrap' and
# 'criterion' were listed but never forwarded), so the forest configuration
# does not depend on the library's defaults.
bclf = BalancedRandomForestClassifier(random_state=123, **best_param1)

# 5-fold stratified cross-validation without shuffling.
# `random_state` is intentionally not set: it has no effect when shuffle=False,
# and newer scikit-learn releases raise a ValueError for that combination.
# The resulting folds are the same deterministic, unshuffled folds as before.
rfecv = RFECV(bclf, step=1, cv=StratifiedKFold(n_splits=5), scoring='balanced_accuracy')

# apply RFE to data
fit = rfecv.fit(SX, Y)

# Names of the CpGs retained by RFE (ranking == 1). Derived from the fitted
# selector so that every feature is inspected (the previous hard-coded
# range(0, 156) could miss the last column) and stored under a name that
# does not shadow the builtin `list`.
selected_features = [cpg for cpg, rank in zip(SX.columns, rfecv.ranking_) if rank == 1]

# Mean cross-validated score for each number of features evaluated.
# `grid_scores_` no longer exists in newer scikit-learn releases, where
# `cv_results_["mean_test_score"]` holds the same per-step means.
if hasattr(rfecv, "cv_results_"):
    cv_scores = rfecv.cv_results_["mean_test_score"]
else:
    cv_scores = rfecv.grid_scores_

print("Optimal number of features : %d" % rfecv.n_features_)
# 110 features

# RFECV is run with step=1, so entry k-1 holds the score obtained with k features
print("Accuracy: \n", cv_scores[rfecv.n_features_ - 1])

print("Feature Selected: \n", selected_features)


# Plot number of features against cross-validation scores
plt.figure()
plt.xlabel("Number of features selected")
plt.ylabel("Cross-validation balanced accuracy score")
plt.plot(range(1, len(cv_scores) + 1), cv_scores)
# Fixed y-range for the plot (scores outside 0.42-0.60 would be clipped from view)
plt.ylim(0.42, 0.60)
plt.savefig('Feature_selection_balancedRFE_childhood_mrs.pdf')